import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split, KFold
from sklearn.metrics import confusion_matrix
from PIL import Image
import math
import seaborn as sns
import numpy as np
import os
import cv2
import shutil
import pandas as pd
import random
import time
# silence warnings
import warnings
warnings.filterwarnings('ignore')
# setup for multiple outputs from single cell
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = 'all'
The goal of this project is to construct an image classification system using a PyTorch neural network to classify nine common mushroom species. The image classification training data sets were extracted from video clips. The nine mushroom species were selected for their unique appearances.
# Display the nine mushroom species
# This script's ("PyTorch_Mushroom_Image_Classification.ipynb") directory to provide
# relative location of folder ('Pics') holding pictures
# Adjust this line to reflect any new location
location_of_this_ipynb_file = '/media/ijmg/SSD_FOUR_TB/ACADEMICS_101/MY_PROJECTS/ADDED_PROJECTS/Fungi/'
# Build the nine picture paths (all live in relative folder 'Pics')
# from the species file-name stems
species_stems = ['amanita_muscaria', 'calocera_viscosa', 'clathrus_ruber',
                 'coprinus_comatus', 'favolaschia_calocera', 'ganoderma_lucidum',
                 'laetiporus_sulphureus', 'morchella_esculenta', 'phallus_indusiatus']
image_paths = [os.path.join(location_of_this_ipynb_file, 'Pics/' + stem + '.jpeg')
               for stem in species_stems]
# Caption for each species: scientific name plus common name
labels = ['1-- Amanita muscaria \n Common Name: "Fly Agaric Mushroom"',
'2-- Calocera viscosa \n Common Name: "Yellow Staghorn Mushroom"',
'3-- Clathrus ruber \n Common Name: "Red Cage Lattice Stinkhorn"',
'4-- Coprinus comatus \n Common Name: "Shaggy Ink Cap Mushroom"',
'5-- Favolaschia calocera \n Common Name: "Orange Pore Fungus"',
'6-- Ganoderma lucidum \n Common Name: "Reishi Garnished Conk Mushroom"',
'7-- Laetiporus sulphureus \n Common Name: "Chicken of The Woods Mushroom"',
'8-- Morchella esculenta \n Common Name: "Morel Mushroom"',
'9-- Phallus indusiatus \n Common Name: "Bridal Veil Stinkhorn Mushroom"'
]
# Number of images to display in the grid
num_images = len(image_paths)
# Show the plot: a 3 x 3 grid, one panel per species
print('THE NINE MUSHROOM SPECIES:\n')
plt.figure(figsize=(15, 15))
for panel, (picture_path, caption) in enumerate(zip(image_paths, labels), start=1):
    plt.subplot(3, 3, panel)
    # Load with PIL, then hand a NumPy array to matplotlib
    picture = np.array(Image.open(picture_path))
    plt.imshow(picture)
    plt.title(caption, fontsize=15)
    plt.axis('off')
plt.tight_layout()
plt.show()
THE NINE MUSHROOM SPECIES:
For each mushroom species 25 high quality images were collected using the species name in a standard internet image search. These images were then resized to 300 x 300, renamed according to the mushroom genus and species with numerical tags (e.g. amanita_muscaria_001.jpeg, amanita_muscaria_002.jpeg, ... amanita_muscaria_025.jpeg), and finally loaded into the test folder under the fungi_dataset folder. The train folder under the fungi_dataset folder contains the video-extracted training images for each mushroom species. Extracted frames were also resized to 300 x 300 but left with their default names (e.g. frame_0001.jpeg, frame_0002.jpeg, ...). The overall directory layout for the test and train dataset folders is shown below.
# Show the fungi_dataset directory-layout diagram (Figure 1)
map_path = os.path.join(location_of_this_ipynb_file, 'Pics/fungi_dataset_directory_map.jpeg')
# Load with PIL and convert to a NumPy array for matplotlib
directory_map = np.array(Image.open(map_path))
plt.figure(figsize=(10, 10))
plt.axis('off')
plt.title("Figure 1. Directory Layout for Test and Train Dataset Folders ")
plt.imshow(directory_map)
Full Directory Contents: (including fungi_dataset folder)
This Jupyter Notebook file:
-- PyTorch_Mushroom_Image_Classification.ipynb
Image folders:
-- fungi_dataset folder (holding the train and test data set folders)
-- Pics folder (holding the images used in the ipynb file)
Video files (nine .mp4 video montage clips from which training data set images will be extracted):
-- amanita_muscaria.mp4
-- calocera_viscosa.mp4
-- clathrus_ruber.mp4
-- coprinus_comatus.mp4
-- favolaschia_calocera.mp4
-- ganoderma_lucidum.mp4
-- laetiporus_sulphureus.mp4
-- morchella_esculenta.mp4
-- phallus_indusiatus.mp4
All work was done on a linux Ubuntu 22.04.3 LTS operating system.
The project began with a search of YouTube for videos of each mushroom species. Once several suitable videos were located, an open source, linux based video screen capture tool, SimpleScreenRecorder version 0.3.11 (shown below in Figures 2 and 3), was used to capture the relevant sections of each video. The SimpleScreenRecorder software allows for selection of specific sections of the screen and output file format (here, .mp4 format was used).
The next step was selecting approximately 10 to 12 highly relevant one second sections from each of the videos screen recorded for each species. These one second clips were then merged into .mp4 files that would later provide the frames used as training set image data. This was done using an open source, linux based video editing tool, Kdenlive version 21.12.3 (shown below in Figure 4). This led to the .mp4 videos listed in Table 1.
# Show the screen-capture tool (Figures 2 and 3) and the video editor (Figure 4)
# used to build the montage clips.
plt.figure(figsize=(10, 10))
image_pil_1 = Image.open(os.path.join(location_of_this_ipynb_file, 'Pics/simple_screen_recorder.jpeg'))
image_1 = np.array(image_pil_1)
image_pil_2 = Image.open(os.path.join(location_of_this_ipynb_file, 'Pics/simple_screen_recorder2.jpeg'))
image_2 = np.array(image_pil_2)
plt.subplot(1, 2, 1)
plt.imshow(image_1)
plt.title("Figure 2. Simple Screen Recorder")
plt.axis('off')
plt.subplot(1, 2, 2)
plt.imshow(image_2)
plt.title("Figure 3. Simple Screen Recorder Capturing Screenshot")
plt.axis('off')
plt.figure(figsize=(10, 10))
image_pil_3 = Image.open(os.path.join(location_of_this_ipynb_file, 'Pics/kdenlive.jpeg'))
image_3 = np.array(image_pil_3)
plt.imshow(image_3)
# Fixed typo in the displayed caption ("Creatign" -> "Creating")
plt.title("Figure 4. Kdenlive Creating Video Montage Clip")
plt.axis('off')
The function "video_frame_extract_to_train()" is sent a montaged video file and the destination folder to hold the extracted frames from the video. The "video_frame_extract_to_train()" function also calls the "image_transformer()" function to provide image transformations. Frames extracted from each montage video are transformed by the "image_transformer()" function in groups of six as shown below before being saved into their corresponding training set folder under directory fungi_dataset/train/.
Transforms performed by "image_transformer()" function in groups of six:
First frame of six frame group => transform = resize image
Second frame of six frame group => transform = resize image, rotate = + 30 deg
Third frame of six frame group => transform = resize image, rotate = - 30 deg
Fourth frame of six frame group => transform = resize image, flip horizontally
Fifth frame of six frame group => transform = resize image, flip horizontally, rotate = + 30 deg
Sixth frame of six frame group => transform = resize image, flip horizontally, rotate = - 30 deg
The frames extracted from each montaged video clip through the combined work of the "video_frame_extract_to_train()" and "image_transformer()" functions are summarized in Table 1 below.
# Display Table 1, which summarizes the frames extracted per montage video
table_path = os.path.join(location_of_this_ipynb_file, 'Pics/montage_videos_table.jpeg')
table_image = np.array(Image.open(table_path))
plt.figure(figsize=(10, 10))
plt.imshow(table_image)
plt.title("Table 1. Extraction of Montage Training Video Frames into Transformed Training Images")
plt.axis('off')
# Function to transform frames in waves or groups of six
def image_transformer(input_image, six_counter):
    """Resize a video frame to 300 x 300 and apply one of six augmentations.

    Parameters
    ----------
    input_image : numpy.ndarray
        A BGR frame as produced by ``cv2.VideoCapture.read()``.
    six_counter : int
        Position (1-6) of the frame within its group of six:
        1 resize only; 2 rotate +30 deg; 3 rotate -30 deg;
        4 horizontal flip; 5 flip then rotate +30 deg; 6 flip then rotate -30 deg.

    Returns
    -------
    numpy.ndarray
        The transformed 300 x 300 image.  The original implementation fell
        through and returned None for a counter outside 1-6; the resized
        frame is now returned as a safe fallback.
    """
    resized_image = cv2.resize(input_image, (300, 300))
    # Positions 4-6 operate on the horizontally flipped frame
    base_image = cv2.flip(resized_image, 1) if six_counter in (4, 5, 6) else resized_image
    # Rotation angle per position; positions 1 and 4 need no rotation
    angle = {2: 30, 3: -30, 5: 30, 6: -30}.get(six_counter, 0)
    if angle == 0:
        return base_image
    rotation_matrix = cv2.getRotationMatrix2D((300 / 2, 300 / 2), angle, 1)
    return cv2.warpAffine(base_image, rotation_matrix, (300, 300))
# Function to:
# --- 1.) Read each frame
# --- 2.) Send each frame for transformation
# --- 3.) Save each transformed frame as a JPEG image into 'training_folder_path'
def video_frame_extract_to_train(video_file_path, training_folder_path):
    """Extract every frame of a montage video, transform it, and save it.

    Parameters
    ----------
    video_file_path : str
        Path to the source .mp4 montage clip.
    training_folder_path : str
        Destination folder for the transformed frames (created if absent).

    Each frame is passed through ``image_transformer`` with a counter that
    cycles 1..6, and saved as frame_001.jpeg, frame_002.jpeg, ...
    """
    # Create the output folder if it doesn't exist
    os.makedirs(training_folder_path, exist_ok=True)
    # Create a video_reader object and get the total frame count
    cap = cv2.VideoCapture(video_file_path)
    num_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    for frame_index in range(num_frames):
        # --- 1.) Read each frame; stop early if the stream ends
        ret, current_frame = cap.read()
        if not ret:
            break
        # --- 2.) Transform; counter cycles 1..6 to select the augmentation
        six_counter = frame_index % 6 + 1
        transformed_frame = image_transformer(current_frame, six_counter)
        # --- 3.) Save as frame_NNN.jpeg into the training folder.
        # BUG FIX: the original passed one pre-concatenated string to
        # os.path.join, relying on a trailing slash in the folder path;
        # join the folder and file name properly instead.
        num_id = str(frame_index + 1).zfill(3)
        cv2.imwrite(os.path.join(training_folder_path, 'frame_' + num_id + '.jpeg'),
                    transformed_frame)
    # Release the video capture object
    cap.release()
# Construct the training image data sets.  Each species has one montage
# video ('<stem>.mp4') whose frames land in 'fungi_dataset/train/<stem>/',
# so both paths can be derived from a single list of species stems.
species_stems = ['amanita_muscaria', 'calocera_viscosa', 'clathrus_ruber',
                 'coprinus_comatus', 'favolaschia_calocera', 'ganoderma_lucidum',
                 'laetiporus_sulphureus', 'morchella_esculenta', 'phallus_indusiatus']
for stem in species_stems:
    video_path = os.path.join(location_of_this_ipynb_file, stem + '.mp4')
    folder_path = os.path.join(location_of_this_ipynb_file, 'fungi_dataset/train/' + stem + '/')
    video_frame_extract_to_train(video_path, folder_path)
The training image set folders now contain the numbers of images shown in the rightmost column of Table 1.
# Helper: show a heading plus a single row of randomly chosen JPEG samples
def _show_sample_row(folder_path, heading, folder_species, folder_everyday_name, num_images):
    """Print *heading* and display *num_images* random JPEGs from *folder_path*."""
    # Get a list of all image files in the folder
    image_files = [f for f in os.listdir(folder_path) if f.lower().endswith('jpeg')]
    # BUG FIX: never request more samples than the folder holds
    # (random.sample raises ValueError otherwise)
    selected_images = random.sample(image_files, min(num_images, len(image_files)))
    fig, axes = plt.subplots(1, num_images, figsize=(12, 3))
    print(heading, str(folder_species), '\t', str(folder_everyday_name))
    for i, image_file in enumerate(selected_images):
        img = Image.open(os.path.join(folder_path, image_file))
        axes[i].imshow(img)
        axes[i].axis('off')
    plt.show()

# Function to display image samples of test and train sets for each mushroom species
def display_random_image_sets(folder_species, folder_everyday_name, num_images=4):
    """Show random test-set and train-set sample images for one species.

    Parameters
    ----------
    folder_species : str
        Sub-folder name under fungi_dataset/test/ and fungi_dataset/train/.
    folder_everyday_name : str
        Human-readable common name printed in the heading.
    num_images : int
        Number of random samples to show per row (default 4).
    """
    print('======================================================================================================')
    print('======================================================================================================')
    # The test and train rows differ only in the sub-folder and heading,
    # so the duplicated display code is factored into _show_sample_row.
    for subset, heading in (('test', 'TEST SET SAMPLE IMAGES FOR '),
                            ('train', 'TRAIN SET SAMPLE IMAGES FOR ')):
        folder_path = os.path.join(location_of_this_ipynb_file,
                                   'fungi_dataset/' + subset + '/' + str(folder_species) + '/')
        _show_sample_row(folder_path, heading, folder_species, folder_everyday_name, num_images)
# (species directory name, human-readable common name) pairs
species_and_common_names = [
    ('amanita_muscaria', 'Common Name: "Fly Agaric Mushroom"'),
    ('calocera_viscosa', 'Common Name: "Yellow Staghorn Mushroom"'),
    ('clathrus_ruber', 'Common Name: "Red Cage Lattice Stinkhorn"'),
    ('coprinus_comatus', 'Common Name: "Shaggy Ink Cap Mushroom"'),
    ('favolaschia_calocera', 'Common Name: "Orange Pore Fungus"'),
    ('ganoderma_lucidum', 'Common Name: "Reishi Garnished Conk Mushroom"'),
    ('laetiporus_sulphureus', 'Common Name: "Chicken of The Woods Mushroom"'),
    ('morchella_esculenta', 'Common Name: "Morel Mushroom"'),
    ('phallus_indusiatus', 'Common Name: "Bridal Veil Stinkhorn Mushroom"'),
]
# Show random test/train sample images for every species
for species, everyday_name in species_and_common_names:
    display_random_image_sets(species, everyday_name)
====================================================================================================== ====================================================================================================== TEST SET SAMPLE IMAGES FOR amanita_muscaria Common Name: "Fly Agaric Mushroom"
TRAIN SET SAMPLE IMAGES FOR amanita_muscaria Common Name: "Fly Agaric Mushroom"
====================================================================================================== ====================================================================================================== TEST SET SAMPLE IMAGES FOR calocera_viscosa Common Name: "Yellow Staghorn Mushroom"
TRAIN SET SAMPLE IMAGES FOR calocera_viscosa Common Name: "Yellow Staghorn Mushroom"
====================================================================================================== ====================================================================================================== TEST SET SAMPLE IMAGES FOR clathrus_ruber Common Name: "Red Cage Lattice Stinkhorn"
TRAIN SET SAMPLE IMAGES FOR clathrus_ruber Common Name: "Red Cage Lattice Stinkhorn"
====================================================================================================== ====================================================================================================== TEST SET SAMPLE IMAGES FOR coprinus_comatus Common Name: "Shaggy Ink Cap Mushroom"
TRAIN SET SAMPLE IMAGES FOR coprinus_comatus Common Name: "Shaggy Ink Cap Mushroom"
====================================================================================================== ====================================================================================================== TEST SET SAMPLE IMAGES FOR favolaschia_calocera Common Name: "Orange Pore Fungus"
TRAIN SET SAMPLE IMAGES FOR favolaschia_calocera Common Name: "Orange Pore Fungus"
====================================================================================================== ====================================================================================================== TEST SET SAMPLE IMAGES FOR ganoderma_lucidum Common Name: "Reishi Garnished Conk Mushroom"
TRAIN SET SAMPLE IMAGES FOR ganoderma_lucidum Common Name: "Reishi Garnished Conk Mushroom"
====================================================================================================== ====================================================================================================== TEST SET SAMPLE IMAGES FOR laetiporus_sulphureus Common Name: "Chicken of The Woods Mushroom"
TRAIN SET SAMPLE IMAGES FOR laetiporus_sulphureus Common Name: "Chicken of The Woods Mushroom"
====================================================================================================== ====================================================================================================== TEST SET SAMPLE IMAGES FOR morchella_esculenta Common Name: "Morel Mushroom"
TRAIN SET SAMPLE IMAGES FOR morchella_esculenta Common Name: "Morel Mushroom"
====================================================================================================== ====================================================================================================== TEST SET SAMPLE IMAGES FOR phallus_indusiatus Common Name: "Bridal Veil Stinkhorn Mushroom"
TRAIN SET SAMPLE IMAGES FOR phallus_indusiatus Common Name: "Bridal Veil Stinkhorn Mushroom"
# Build the datasets, dataloaders, and the ResNet50 model.
# BUG FIX: torch / torchvision names were used below but never imported
# anywhere in this notebook; import them here so the cell is runnable.
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, models, transforms

# Set device: train on the GPU when available, otherwise the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Define your dataset paths for training and testing
train_data_path = os.path.join(location_of_this_ipynb_file, 'fungi_dataset/train/')
test_data_path = os.path.join(location_of_this_ipynb_file, 'fungi_dataset/test/')
# Data transformations: resize to ResNet's expected 224 x 224 input and
# normalize each channel to the [-1, 1] range
data_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
# Load datasets (ImageFolder: one sub-folder per species => one class each)
train_dataset = datasets.ImageFolder(train_data_path, transform=data_transform)
test_dataset = datasets.ImageFolder(test_data_path, transform=data_transform)
# Define dataloaders
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)
# Define the ResNet50 model pre-trained on ImageNet
model = models.resnet50(pretrained=True)
# Replace the final fully connected layer for the mushroom classification task.
# BUG FIX: the original referenced class_names before it was defined; take
# the class count straight from the ImageFolder dataset instead.
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, len(train_dataset.classes))
model.to(device)
# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
# Function to train the model
def train_model(model, criterion, optimizer, num_epochs=10, dataloader=None):
    """Train *model* for *num_epochs* epochs and record per-epoch metrics.

    Parameters
    ----------
    model : torch.nn.Module
        The network to train (already on the target device).
    criterion : callable
        Loss function, e.g. nn.CrossEntropyLoss().
    optimizer : torch.optim.Optimizer
        Optimizer bound to the model's parameters.
    num_epochs : int
        Number of training epochs (default 10).
    dataloader : torch.utils.data.DataLoader, optional
        Training loader; defaults to the module-level ``train_loader``
        (backward-compatible generalization of the original).

    Returns
    -------
    (model, loss_list, accuracy_list)
        The trained model plus per-epoch loss and accuracy histories.
    """
    # BUG FIX: the original iterated over undefined globals
    # 'dataloaders' / 'dataset_sizes'; use the loader built above (or a
    # caller-supplied one) and its dataset length instead.
    if dataloader is None:
        dataloader = train_loader
    num_samples = len(dataloader.dataset)
    loss_list = []
    accuracy_list = []
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        corrects = 0
        # Record start time for per-epoch wall-clock duration
        start_time = time.time()
        for inputs, labels in dataloader:
            print('.', end='')  # one progress dot per batch
            # BUG FIX: move the batch to the same device as the model
            inputs = inputs.to(device)
            labels = labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            # Accumulate loss weighted by batch size, and correct predictions
            running_loss += loss.item() * inputs.size(0)
            _, preds = torch.max(outputs, 1)
            corrects += torch.sum(preds == labels.data)
        elapsed_time = time.time() - start_time
        epoch_loss = running_loss / num_samples
        epoch_acc = corrects.double() / num_samples
        print(f'\nEpoch {epoch + 1}/{num_epochs}'
              f' Loss: {epoch_loss:.4f}'
              f' Acc: {epoch_acc:.4f}'
              f' Epoch Duration: {elapsed_time:.0f} seconds')
        loss_list.append(epoch_loss)
        # Store a plain float rather than a (possibly GPU-resident) tensor
        accuracy_list.append(epoch_acc.item())
    return model, loss_list, accuracy_list
# Train the model
trained_model, loss_list, accuracy_list = train_model(model, criterion, optimizer, num_epochs=10)
# silence warnings
import warnings
warnings.filterwarnings('ignore')
................................................................................................ Epoch 1/10 Loss: 0.7436 Acc: 0.8454 Epoch Duration: 1834.5162143707275 seconds ................................................................................................ Epoch 2/10 Loss: 0.0963 Acc: 0.9813 Epoch Duration: 2089.622031211853 seconds ................................................................................................ Epoch 3/10 Loss: 0.0667 Acc: 0.9859 Epoch Duration: 2043.472440481186 seconds ................................................................................................ Epoch 4/10 Loss: 0.0534 Acc: 0.9872 Epoch Duration: 2234.8858783245087 seconds ................................................................................................ Epoch 5/10 Loss: 0.0374 Acc: 0.9918 Epoch Duration: 1981.3628253936768 seconds ................................................................................................ Epoch 6/10 Loss: 0.0282 Acc: 0.9941 Epoch Duration: 2011.7476074695587 seconds ................................................................................................ Epoch 7/10 Loss: 0.0300 Acc: 0.9928 Epoch Duration: 2421.4473345279694 seconds ................................................................................................ Epoch 8/10 Loss: 0.0275 Acc: 0.9924 Epoch Duration: 2390.083594560623 seconds ................................................................................................ Epoch 9/10 Loss: 0.0209 Acc: 0.9954 Epoch Duration: 2287.932721853256 seconds ................................................................................................ Epoch 10/10 Loss: 0.0180 Acc: 0.9957 Epoch Duration: 1763.1282176971436 seconds
# Plot the per-epoch training loss and accuracy side by side
plt.figure(figsize=(10, 4))
for panel, (history, metric) in enumerate([(loss_list, 'Loss'),
                                           (accuracy_list, 'Accuracy')], start=1):
    plt.subplot(1, 2, panel)
    plt.plot(history, label=f'Training {metric}')
    plt.xlabel('Epoch')
    plt.ylabel(metric)
    plt.title(f'Training {metric}')
    plt.legend()
plt.tight_layout()
plt.show()
# Class labels for reporting.  NOTE(review): the order must match the
# integer indices torchvision's ImageFolder assigns, i.e. alphabetical by
# sub-folder name under fungi_dataset/train|test — which this list is.
class_names = ['amanita_muscaria',
'calocera_viscosa',
'clathrus_ruber',
'coprinus_comatus',
'favolaschia_calocera',
'ganoderma_lucidum',
'laetiporus_sulphureus',
'morchella_esculenta',
'phallus_indusiatus']
# Function to evaluate the model on the test set
def evaluate_model(model, dataloader=None):
    """Evaluate *model*, printing test loss/accuracy.

    Parameters
    ----------
    model : torch.nn.Module
        The trained network (already on the target device).
    dataloader : torch.utils.data.DataLoader, optional
        Evaluation loader; defaults to the module-level ``test_loader``
        (backward-compatible generalization of the original).

    Returns
    -------
    (all_labels, all_preds, misclassified_images)
        True label indices, predicted label indices, and the misclassified
        input tensors (on CPU) in encounter order.
    """
    # BUG FIX: 'dataloaders' / 'dataset_sizes' were never defined; use the
    # test_loader built above (or a caller-supplied loader) instead.
    if dataloader is None:
        dataloader = test_loader
    num_samples = len(dataloader.dataset)
    model.eval()
    running_loss = 0.0
    corrects = 0
    all_labels = []
    all_preds = []
    misclassified_images = []
    with torch.no_grad():
        for inputs, labels in dataloader:
            # BUG FIX: move the batch to the model's device
            inputs = inputs.to(device)
            labels = labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            running_loss += loss.item() * inputs.size(0)
            _, preds = torch.max(outputs, 1)
            corrects += torch.sum(preds == labels.data)
            # .cpu() so .numpy() also works when evaluating on the GPU
            all_labels.extend(labels.cpu().numpy())
            all_preds.extend(preds.cpu().numpy())
            # Collect each misclassified image as a CPU tensor
            for i in range(inputs.size(0)):
                if labels[i].item() != preds[i].item():
                    misclassified_images.append(inputs[i].cpu())
    test_loss = running_loss / num_samples
    test_acc = corrects.double() / num_samples
    print(f'Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.4f}')
    return all_labels, all_preds, misclassified_images
# Evaluate the model
true_labels, predicted_labels, mislabeled_images = evaluate_model(trained_model)
Test Loss: 0.3864, Test Acc: 0.8756
# Function to plot confusion matrix
def plot_confusion_matrix(true_labels, predicted_labels, class_names):
    """Render a heatmap of the confusion matrix for the given predictions."""
    matrix = confusion_matrix(true_labels, predicted_labels)
    plt.figure(figsize=(8, 6))
    # Annotate each cell with its integer count
    sns.heatmap(matrix, annot=True, fmt='d', cmap='Blues',
                xticklabels=class_names, yticklabels=class_names)
    plt.ylabel('True Labels')
    plt.xlabel('Predicted Labels')
    plt.title('Confusion Matrix')
    plt.show()
# Plot confusion matrix
plot_confusion_matrix(true_labels, predicted_labels, class_names)
# Summarize the test-set results in a one-row table
num_predictions = len(predicted_labels)
num_errors = len(mislabeled_images)
accuracy = str(round((1 - (num_errors / num_predictions)) * 100, 2)) + ' %'
# Construct and display the summary DataFrame
df = pd.DataFrame.from_dict({'Model': ['ResNet50'],
                             'Predictions': [num_predictions],
                             'Errors': [num_errors],
                             'Accuracy': [accuracy]}).set_index('Model')
print(df)
Predictions Errors Accuracy Model ResNet50 225 28 87.56 %
# Display all mislabeled images with their true and predicted class names
mislabeled_indices = np.where(np.array(true_labels) != np.array(predicted_labels))[0]
# BUG FIX: the original hard-coded an 8 x 4 grid (32 slots) and would raise
# ValueError for more than 32 mislabeled images; size the grid to the data.
num_columns = 4
num_rows = max(1, math.ceil(len(mislabeled_indices) / num_columns))
plt.figure(figsize=(30, 30))
for i, index in enumerate(mislabeled_indices):
    true_label_index = true_labels[index]
    predicted_label_index = predicted_labels[index]
    # Min-max rescale to [0, 1] to undo normalization for display
    image_np = mislabeled_images[i].permute(1, 2, 0).numpy()
    value_range = image_np.max() - image_np.min()
    if value_range > 0:  # guard against division by zero on a constant image
        image_np = (image_np - image_np.min()) / value_range
    image_pil = Image.fromarray((image_np * 255).astype('uint8'))
    plt.subplot(num_rows, num_columns, i + 1)
    plt.imshow(image_pil)
    plt.title(f'True: {class_names[true_label_index]} \n'
              f' Predicted: {class_names[predicted_label_index]}', fontsize=30)
    plt.axis('off')
plt.tight_layout()
plt.show()
OBSERVATIONS:
The model had the most difficulty with:
1.) Clathrus ruber (Red Cage Lattice Stinkhorn Mushroom)
2.) Coprinus comatus (Shaggy Ink Cap Mushroom)
3.) Morchella esculenta (Morel Mushroom)
Clathrus ruber was commonly mistaken for Favolaschia calocera (Orange Pore Fungus) with 4 errors and
Morchella esculenta (Morel Mushroom) with 3 errors. All three species share a fenestrated, spindly appearance on some sections of their structures.
Difficulties classifying Coprinus comatus are to be expected. The video used to provide training data images, "coprinus_comatus.mp4", shows the dramatic changes in appearance the species undergoes during growth. This may make Coprinus comatus at times appear like another species. In this case, the model misclassified it as Phallus indusiatus (Bridal Veil Stinkhorn Mushroom) in 6 cases. This seems reasonable since both species share an elongated, torpedo-like shape.
The wrinkled cap of Morchella esculenta may have been a factor in its misclassification as Favolaschia calocera 3 times and Phallus indusiatus in 2 cases.
# Close by re-displaying the 3 x 3 grid of the nine mushroom species,
# reusing image_paths / labels / num_images defined earlier.
print('THE NINE MUSHROOM SPECIES:\n')
plt.figure(figsize=(15, 15))
for position in range(1, num_images + 1):
    plt.subplot(3, 3, position)
    # Load via PIL and convert to a NumPy array for matplotlib
    photo = np.array(Image.open(image_paths[position - 1]))
    plt.imshow(photo)
    plt.title(labels[position - 1], fontsize=15)
    plt.axis('off')
plt.tight_layout()
plt.show()
THE NINE MUSHROOM SPECIES: